import pymysql
import pymysql.cursors
from DBUtils.PooledDB import PooledDB
from sshtunnel import SSHTunnelForwarder
from common.db.dbconfigs import *
'''
sql:将要执行的sql. 如果以select开头,有返回值;
config:将要连接的数据库,如果为空默认连接百度云
'''


def execute(sql, config, dictionary=False, connection=None, autocommit=True):
    if not connection:
        connection = _connection(config)

    results = _execute(sql, connection, dictionary)
    if autocommit:
        connection.commit()
    return results


def _execute(sql, connection, dictionary=False):
    cursorclass = pymysql.cursors.Cursor
    if dictionary:
        cursorclass = pymysql.cursors.DictCursor
    cursor = connection.cursor(cursorclass)

    cursor.execute(sql)
    results = None
    if sql.lstrip().upper().startswith("SELECT"):
        results = cursor.fetchall()
    cursor.close()
    return results


# 数据库连接池
db_pools = {}


def _connection(config):
    pool_key = str(config)
    if pool_key not in db_pools:
        '''
        1.maxcached确定了实际连接(真实和数据库连接,不是被pooleddb封装的那个)的数量
        验证: set @s=1. 当maxcached=1时,所有线程都可以拿到. 当maxcached=2时,有的可以拿到,有的拿不到.
        
        2.获取的连接关闭不关闭都一样, 关闭了也只是标识可以被其他线程使用. 
        实际使用体验是这样的: 这个连接好像不是同时只被一个线程使用,所有线程都可以同时使用它,
        所有一个线程不释放连接, 不影响其他线程使用.
        验证: cursor.close()之后不关闭连接, 也不退出方法,休眠10s,同时3个线程执行这个语句.
        最终在休眠3s后所有的线程同时执行完了.
        '''
        pool = PooledDB(creator=pymysql, mincached=1, maxcached=2, maxconnections=4, blocking=True,
                        user=config["user"],
                        passwd=config["passwd"],
                        host=config["host"],
                        db=config['db'],
                        port=config['port'],
                        charset=config['charset'],
                        )
        db_pools[pool_key] = pool

    return db_pools[pool_key].connection()

# def execute(sql, config=None, dictionary=False):
#     if not config:
#         config = default
#
#     cursorclass = pymysql.cursors.Cursor
#     if dictionary:
#         cursorclass = pymysql.cursors.DictCursor
#
#     if config.__contains__("ssh"):
#         with SSHTunnelForwarder(
#                 (config["ssh"]["host"], config["ssh"]["port"]),
#                 ssh_username=config["ssh"]["username"],
#                 ssh_pkey=config["ssh"]["pkey"],
#                 ssh_private_key_password=config["ssh"]["password"],
#                 remote_bind_address=(config["host"], config["port"]),
#                 local_bind_address=('0.0.0.0', 10022)
#         ) as tunnel:
#             connect = pymysql.connect(host='127.0.0.1',  # 此处必须是是127.0.0.1
#                                       port=10022,
#                                       user=config["user"],
#                                       passwd=config["passwd"],
#                                       db=config["db"],
#                                       charset=config["charset"],
#                                       cursorclass=cursorclass
#                                       )
#             return _execute(sql, connect)
